Abstract
誤差逆伝播法が正しく動くようになったことによってネットワークを任意の入力に対して任意の出力が出るように訓練できるようになった。しかし、実際のニューラルネットワークは1つの入力と出力をマップするものではない。たくさんの出力に対してそれぞれ期待する入力を刷り込んでやらないと使い物にならない。
この記事ではこれをモデルの合成と呼ぶ。Method
今回はモデルの合成のため、単純な方法でおそらく一番最初に思いつく方法であろう「平均を取る」という方法をとってみる。
例えば、XORを返すネットワークを作ろうと思った時に、まず誤差逆伝播によって以下のネットワークを生成する。準備するネットワーク群1,0にするとバイアスやら重みやらが爆発するので(シグモヰド函数のしぇい?)0.1, 0.9の微妙な値にとどめておく。ネットワークのサイズは以下のようにした。
Net1: [0.1, 0.1] === NN ===> [0.1]
Net2: [0.1, 0.9] === NN ===> [0.9]
Net3: [0.9, 0.1] === NN ===> [0.9]
Net4: [0.9, 0.9] === NN ===> [0.1]
ネットワークのサイズ
2 Inputs → 8 Dence → 1 Output
諸パラメータ
学習率 0.1
バイアス学習率 0.1
Result
Net1:
0.1[]0/ 0.1[]0/
0.461225[0.0626076, 0.726138, ]-0.234287/ 0.454895[0.511846, 0.216285, ]-0.253726/ 0.459151[0.847629, 0.494977, ]-0.29802/ 0.430412[-0.0373044, 0.746381, ]-0.351076/ 0.436245[0.408468, 0.711115, ]-0.368372/ 0.42463[0.139749, 0.583252, ]-0.376096/ 0.399576[0.0489678, 0.277297, ]-0.43986/ 0.393289[0.111047, -0.199863, ]-0.424628/
0.0908007[0.280726, 0.142204, 0.521335, 0.501612, 0.00522412, -0.129593, 0.46911, -0.345765, ]-2.95204/
Net2:
0.1[]0/ 0.9[]0/
0.503215[0.146477, 0.274475, ]-0.248816/ 0.588823[0.0621211, 0.690219, ]-0.268308/ 0.534071[0.900686, 0.400186, ]-0.313739/ 0.584462[0.689173, 0.645649, ]-0.308883/ 0.32696[0.377266, -0.346023, ]-0.448273/ 0.379205[0.349677, -0.016169, ]-0.513341/ 0.489395[0.0438794, 0.492251, ]-0.489839/ 0.517847[0.891157, 0.53516, ]-0.49934/
0.819289[0.0348507, 0.88743, 0.499484, 0.248255, 0.924724, 0.774956, 0.322028, 0.344511, ]-0.372613/
Net3:
0.9[]0/ 0.1[]0/
0.531983[0.183554, 0.973735, ]-0.134466/ 0.536353[0.24005, 0.884855, ]-0.158863/ 0.607243[0.577, 0.858314, ]-0.169395/ 0.589007[0.602019, 0.182352, ]-0.200188/ 0.50653[0.263784, 0.112243, ]-0.22251/ 0.512342[0.218085, 0.874374, ]-0.234336/ 0.6097[0.653418, 0.843184, ]-0.226344/ 0.570578[0.52439, 0.633861, ]-0.251128/
0.87136[0.490456, 0.798993, 0.67661, 0.780432, 0.0949463, 0.2536, 0.37344, 0.279007, ]-0.211869/
Net4:
0.9[]0/ 0.9[]0/
0.536251[0.545349, 0.30968, ]-0.624268/ 0.333379[-0.115768, 0.248922, ]-0.81278/ 0.266111[-0.287083, 0.114366, ]-0.858998/ 0.254287[0.0547503, -0.294857, ]-0.859783/ 0.338534[0.461304, -0.279614, ]-0.833353/ 0.227082[-0.410675, 0.0800331, ]-0.927282/ 0.19593[-0.164999, -0.328325, ]-0.967938/ 0.493162[0.627668, 0.133931, ]-0.712795/
0.111399[-0.0478862, 0.0449693, -0.233222, -0.176751, 0.304124, -0.222535, -0.0309973, 0.121029, ]-2.06487/
平均を取ってできたXORモデル:
0[]0/ 0[]0/
0[0.173263, 0.199853, ]-0.378459/ 0[0.243409, 0.260228, ]-0.447822/ 0[0.289539, 0.122841, ]-0.490845/ 0[0.115915, 0.350059, ]-0.428634/ 0[0.0149099, 0.0515775, ]-0.542516/ 0[0.308048, 0.464876, ]-0.446924/ 0[0.204587, 0.125914, ]-0.49599/ 0[-0.0283544, 0.344654, ]-0.638818/
0[0.20507, 0.197326, 0.341477, 0.312974, 0.371548, 0.0730721, 0.221641, 0.527845, ]-1.3484/
それぞれの値を当てはめて順伝播してみる:
0[]0/ 0[]0/
0.423003[0.234497, 0.571007, ]-0.310459/ 0.407715[0.174562, 0.51007, ]-0.373419/ 0.398903[0.509558, 0.466961, ]-0.410038/ 0.39413[0.327159, 0.319881, ]-0.429983/ 0.38506[0.377705, 0.0494304, ]-0.468127/ 0.374546[0.074209, 0.380373, ]-0.512764/ 0.370285[0.145316, 0.321101, ]-0.530995/ 0.384149[0.538566, 0.275772, ]-0.471973/
0.373951[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/
1[]0/ 0[]0/
0.481018[0.234497, 0.571007, ]-0.310459/ 0.450449[0.174562, 0.51007, ]-0.373419/ 0.524859[0.509558, 0.466961, ]-0.410038/ 0.474317[0.327159, 0.319881, ]-0.429983/ 0.47741[0.377705, 0.0494304, ]-0.468127/ 0.392085[0.074209, 0.380373, ]-0.512764/ 0.404758[0.145316, 0.321101, ]-0.530995/ 0.516642[0.538566, 0.275772, ]-0.471973/
0.412311[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/
0[]0/ 1[]0/
0.564771[0.234497, 0.571007, ]-0.310459/ 0.53411[0.174562, 0.51007, ]-0.373419/ 0.514227[0.509558, 0.466961, ]-0.410038/ 0.472502[0.327159, 0.319881, ]-0.429983/ 0.396829[0.377705, 0.0494304, ]-0.468127/ 0.46695[0.074209, 0.380373, ]-0.512764/ 0.447718[0.145316, 0.321101, ]-0.530995/ 0.451107[0.538566, 0.275772, ]-0.471973/
0.422544[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/
1[]0/ 1[]0/
0.621294[0.234497, 0.571007, ]-0.310459/ 0.577181[0.174562, 0.51007, ]-0.373419/ 0.637951[0.509558, 0.466961, ]-0.410038/ 0.554052[0.327159, 0.319881, ]-0.429983/ 0.489754[0.377705, 0.0494304, ]-0.468127/ 0.485459[0.074209, 0.380373, ]-0.512764/ 0.483861[0.145316, 0.321101, ]-0.530995/ 0.584765[0.538566, 0.275772, ]-0.471973/
0.462323[0.189537, 0.468399, 0.366052, 0.338387, 0.332255, 0.169107, 0.283395, 0.0996955, ]-1.40035/
まとめ
うーんこのって感じですね。
やはり平均を取るだけじゃ駄目なような気がプンプンしてますね。次はbackprop内に取り込めないか考えてみます。ここら辺を織り合わせて学習してみるということ。ソースはここで公開してますので自由に実行して下しあ。ノシ